import random
import pickle
import os
import bisect
import gymnasium as gym
import numpy as np

from models.actionRobustQ import ActionRobustQ
from models.arrlc import ARRLC
from models.orlc import ORLC



def get_agent(args, state_dim, action_dim):
    """Instantiate the agent selected by ``args.agent_type``.

    Args:
        args: Namespace-like object. ``agent_type`` picks the class. For
            "rq" the attributes ``rho``, ``epsilon``, ``alpha`` and
            ``gamma`` are read; for "arrlc" and "orlc" the attributes
            ``num_episodes``, ``episode_steps_train``, ``rho``, ``iota``
            and ``const`` are read.
        state_dim: Forwarded to the agent as ``n_state``.
        action_dim: Forwarded to the agent as ``n_action``.

    Returns:
        An ``ActionRobustQ``, ``ARRLC`` or ``ORLC`` instance.

    Raises:
        ValueError: If ``args.agent_type`` is not "rq", "arrlc" or "orlc".
    """
    kind = args.agent_type

    if kind == "rq":
        return ActionRobustQ(
            n_state=state_dim,
            n_action=action_dim,
            rho=args.rho,
            epsilon=args.epsilon,
            alpha=args.alpha,
            gamma=args.gamma,
        )

    if kind not in ("arrlc", "orlc"):
        raise ValueError("Only support [rq, orlc, arrlc]")

    # ARRLC and ORLC are constructed with exactly the same keyword set,
    # so only the class differs between the two branches.
    agent_cls = ARRLC if kind == "arrlc" else ORLC
    return agent_cls(
        n_state=state_dim,
        n_action=action_dim,
        n_episode=args.num_episodes,
        n_step=args.episode_steps_train,
        rho=args.rho,
        iota=args.iota,
        const=args.const,
    )



def get_env(args):
    """Create the Gymnasium environment named by ``args.env_name``.

    "FrozenLake-v1" is created with ``is_slippery=True``; every other
    supported id is created with ``gym.make``'s default options.

    Args:
        args: Namespace-like object exposing ``env_name``.

    Returns:
        The environment returned by ``gym.make``.

    Raises:
        ValueError: If ``args.env_name`` is not one of the supported ids.
    """
    name = args.env_name
    default_option_envs = (
        "Taxi-v3",
        "CliffWalking-v0",
        "CartPole-v1",
        "InvertedPendulum-v4",
        "MountainCar-v0",
    )

    if name == "FrozenLake-v1":
        return gym.make(name, is_slippery=True)

    if name in default_option_envs:
        return gym.make(name)

    raise ValueError("Only support [FrozenLake-v1, Taxi-v3, CliffWalking-v0, CartPole-v1, InvertedPendulum-v4, MountainCar-v0]")


def normalize_reward(r: float, args) -> float:
    """Rescale a raw environment reward for the selected agent.

    Agents of type "rq" or "rqh" receive the reward unchanged. For any
    other agent type the reward is transformed per environment:

    * FrozenLake-v1, CartPole-v1, InvertedPendulum-v4: unchanged
    * name containing "Taxi": ``(r + 10) / 30``
    * name containing "CliffWalking": ``(r + 100) / 100``
    * MountainCar-v0: ``r + 1``

    NOTE(review): the affine maps presumably bring each environment's
    reward range into [0, 1] -- confirm against the environments' specs.

    Args:
        r: Raw reward.
        args: Namespace-like object exposing ``agent_type`` and, for
            agents other than "rq"/"rqh", ``env_name``.

    Returns:
        The (possibly rescaled) reward.

    Raises:
        ValueError: If the environment is not recognised.
    """
    if args.agent_type in ("rq", "rqh"):
        return r

    env_name = args.env_name

    if env_name in ("FrozenLake-v1", "CartPole-v1", "InvertedPendulum-v4"):
        return r
    if "Taxi" in env_name:
        return (r + 10) / 30
    if "CliffWalking" in env_name:
        return (r + 100) / 100
    if env_name == "MountainCar-v0":
        return r + 1

    raise ValueError("Only support [FrozenLake, Taxi, CliffWalking, CartPole, InvertedPendulum, MountainCar]")


def take_env_step(env: gym.Env,
                  action: int,
                  already_done: bool,
                  prev_state: int,
                  env_name: str):
    """Advance ``env`` by one step, freezing some environments once done.

    For "InvertedPendulum-v4" the integer ``action`` is first used as an
    index into a seven-entry table; the selected value is halved and
    wrapped in a one-element list before being handed to ``env.step``.

    If ``already_done`` is true and ``env_name`` is one of
    CliffWalking-v0, CartPole-v1, InvertedPendulum-v4 or MountainCar-v0,
    the environment is not stepped and ``(prev_state, 0, True)`` is
    returned. Environments outside that list are always stepped.

    Args:
        env: Environment to step.
        action: Discrete action index.
        already_done: Whether the episode had already finished.
        prev_state: State to repeat when the step is skipped.
        env_name: Environment id controlling the special cases above.

    Returns:
        Tuple ``(next_state, reward, done)``.
    """
    if env_name == "InvertedPendulum-v4":
        levels = (-0.5, -0.2, -0.1, 0, 0.1, 0.2, 0.5)
        action = [levels[action] * 0.5]

    frozen_when_done = (
        "CliffWalking-v0",
        "CartPole-v1",
        "InvertedPendulum-v4",
        "MountainCar-v0",
    )

    if not (already_done and env_name in frozen_when_done):
        # Only the first three of the five values returned by env.step are
        # kept. NOTE(review): in Gymnasium's step API the fourth value is
        # the truncation flag -- confirm callers bound the episode length.
        next_state, reward, done, _, _ = env.step(action)
        return next_state, reward, done

    # Finished episode: repeat the previous state with zero reward.
    return prev_state, 0, True


def perturb_action(args,
                   action: int,
                   num_actions: int) -> int:
    """Replace ``action`` with a perturbed one with probability ``args.p``.

    A uniform draw from ``random.random()`` is compared with ``args.p``;
    when the draw is larger the action is returned untouched. Otherwise
    the replacement depends on ``args.perturb_type``:

    * "random": an action index chosen uniformly from
      ``range(num_actions)``.
    * "fix": a hard-coded index per environment (2 for CliffWalking-v0,
      0 for InvertedPendulum-v4).

    Args:
        args: Namespace-like object exposing ``p``, ``perturb_type`` and,
            for the "fix" mode, ``env_name``.
        action: The agent's chosen action index.
        num_actions: Size of the discrete action space.

    Returns:
        The original or the perturbed action index.

    Raises:
        ValueError: If a perturbation is triggered and ``perturb_type`` is
            neither "random" nor "fix", or the "fix" mode is used with an
            unsupported environment.
    """
    if random.random() > args.p:
        return action

    if args.perturb_type == "random":
        return random.choice(range(num_actions))

    if args.perturb_type == "fix":
        if args.env_name == "CliffWalking-v0":
            return 2
        if args.env_name == "InvertedPendulum-v4":
            return 0
        raise ValueError(
            "perturb_type 'fix' only supports [CliffWalking-v0, "
            "InvertedPendulum-v4], got {!r}".format(args.env_name)
        )

    # Previously an unknown perturb_type fell through and returned None,
    # which violates the declared int return type.
    raise ValueError(
        "Only support perturb_type in [random, fix], got {!r}".format(
            args.perturb_type
        )
    )



def _clamped_bin(value, width, shift, top):
    """Floor-divide ``value`` by ``width``, add ``shift``, clamp to [0, top]."""
    return max(min((value // width) + shift, top), 0)


def discretize_state(state, env_name: str) -> int:
    """Map an observation to a single integer state index.

    * Name containing "CartPole": each of the four components is binned
      into 0..5 and the bins are combined as base-6 digits.
    * Name containing "InvertedPendulum-v4": the first component is
      ignored; the second is located among nine thresholds (0..9), the
      third and fourth are binned into 0..3; the index is
      ``s2 + 10 * s3 + 40 * s4``.
    * Name containing "MountainCar-v0": the first component is binned into
      0..16 and the second into 0..19; the index is ``s1 + 18 * s2``.
    * Anything else: ``state`` is returned unchanged.

    Args:
        state: Observation; a 4-element sequence for CartPole and
            InvertedPendulum, a 2-element sequence for MountainCar.
        env_name: Environment id selecting the scheme.

    Returns:
        The integer state index, or ``state`` itself for other
        environments.
    """
    if "CartPole" in env_name:
        f1, f2, f3, f4 = state

        d0 = _clamped_bin(f1, 1.2, 3, 5)
        d1 = _clamped_bin(f2, 0.5, 3, 5)
        d2 = _clamped_bin(f3, 0.1045, 3, 5)
        d3 = _clamped_bin(f4, 0.5, 3, 5)

        return int(d0 + d1 * 6 + d2 * 6**2 + d3 * 6**3)

    if "InvertedPendulum-v4" in env_name:
        # The first component is unpacked (so the length is still checked)
        # but does not contribute to the index.
        _, f2, f3, f4 = state

        thresholds = [-0.2, -0.1, -0.05, -0.025, 0, 0.025, 0.05, 0.1, 0.2]

        low = bisect.bisect_left(thresholds, f2)
        mid = _clamped_bin(f3, 0.05, 2, 3)
        high = _clamped_bin(f4, 0.05, 2, 3)

        return int(low + mid * 10 + high * 10 * 4)

    if "MountainCar-v0" in env_name:
        f1, f2 = state

        col = _clamped_bin(f1, 0.1, 11, 16)
        row = _clamped_bin(f2, 0.007, 9, 19)

        return int(col + row * 18)

    return state



def save_arrlc_model(agent, output_dir: str):
    """Pickle ``agent`` to ``<output_dir>/agent.pkl``.

    The directory (including missing parents) is created when needed.

    Args:
        agent: Object to serialise with ``pickle.dump``.
        output_dir: Directory that will contain ``agent.pkl``.
    """
    # exist_ok=True removes the check-then-create race of the former
    # os.path.exists() guard when several runs share an output directory.
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "agent.pkl"), "wb") as f:
        pickle.dump(agent, f)


def load_arrlc_model(model_dir: str):
    """Load an agent previously written to ``<model_dir>/agent.pkl``.

    Two on-disk layouts are accepted:

    * a pickled agent object (what ``save_arrlc_model`` writes), which is
      returned as-is;
    * a pickled attribute dictionary, which is restored onto a fresh
      ``ARRLC`` instance.

    Args:
        model_dir: Directory containing ``agent.pkl``.

    Returns:
        The restored agent.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only load checkpoints from a trusted source.
    with open(os.path.join(model_dir, "agent.pkl"), "rb") as f:
        loaded = pickle.load(f)

    # save_arrlc_model pickles the agent itself, so the common case is an
    # already-constructed object. Assigning it to ``__dict__`` (as was done
    # before) raises TypeError because ``__dict__`` must be a dict.
    if not isinstance(loaded, dict):
        return loaded

    # Attribute-dict layout: bypass __init__, whose required arguments are
    # not known here, and restore the saved attributes directly.
    agent = ARRLC.__new__(ARRLC)
    agent.__dict__.update(loaded)
    return agent


